from __future__ import print_function
import torch
import numpy as np
from PIL import Image

import math
import random
import pandas as pd
import cv2
import scipy.ndimage as ndimage
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Compute device used by the helpers in this module: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Module-level floating-point dtype constant (not referenced in the visible part of this file).
dtype = torch.float



class IBUG_300W(torch.utils.data.Dataset):
    """300W facial-landmark dataset producing 256x256 face crops.

    The annotation CSV is expected to hold an ``image_path`` column first,
    followed by the landmark coordinates as ``x0, y0, x1, y1, ...``.

    Args:
        data_path: Dataset root directory. The CSV files are read from
            ``<data_path>/300W/`` and image paths are resolved against
            ``data_path``.
        is_train: Read the training CSV when true, else the full test CSV.
        random_scale: Randomly jitter the crop scale by ``scale_factor``.
        random_flip: Horizontally flip image and landmarks with probability 0.5.
        random_rotation: Randomly rotate by up to ``rot_factor`` degrees
            (``rot_factor`` is 0 here, so the drawn angle is always 0).
    """

    # 1-based (left, right) landmark pairs exchanged by a horizontal flip.
    _FLIP_PAIRS = (
        (1, 17), (2, 16), (3, 15), (4, 14), (5, 13), (6, 12), (7, 11), (8, 10),
        (18, 27), (19, 26), (20, 25), (21, 24), (22, 23),
        (32, 36), (33, 35),
        (37, 46), (38, 45), (39, 44), (40, 43), (41, 48), (42, 47),
        (49, 55), (50, 54), (51, 53), (62, 64), (61, 65), (68, 66), (59, 57), (60, 56),
    )

    def __init__(self, data_path, is_train, random_scale, random_flip, random_rotation):
        self.is_train = is_train
        self.data_root = data_path
        self.scale_factor = 0.05
        self.rot_factor = 0
        self.random_scale = random_scale
        self.random_flip = random_flip
        self.random_rotation = random_rotation

        # Join path components instead of concatenating strings so that a
        # data_path without a trailing separator resolves correctly too.
        csv_name = "300W_train_data.csv" if is_train else "300W_test_data_full.csv"
        self.df = pd.read_csv(os.path.join(data_path, "300W", csv_name))

        # Per-channel normalisation statistics applied to the [0, 1] RGB crop.
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.df)

    def __getitem__(self, index):
        """Load, augment, crop and normalise one sample.

        Returns:
            Tuple ``(img, tpts, pts, center, scale)``:
            ``img`` is the normalised crop as a channel-first float tensor,
            ``tpts`` the crop-space landmarks flattened in ``[y, x]`` order as
            a long tensor, ``pts`` the image-space landmarks (after any flip)
            flattened in ``[x, y]`` order as a float tensor, ``center`` the
            face-box centre (x, y) and ``scale`` the box size divided by 200.
        """
        row = self.df.iloc[index]
        img_path = os.path.join(self.data_root, row['image_path'])
        # The context manager releases the file handle as soon as pixels are read.
        with Image.open(img_path) as handle:
            img = np.array(handle.convert('RGB'), dtype=np.float32)

        # 300W dataset : 68 landmarks. [[x0, y0], [x1, y1], ..., [x67, y67]] order.
        # The row mixes a string path with numbers, so force a numeric copy
        # rather than carrying an object-dtype array through the pipeline.
        pts = np.array(row.iloc[1:], dtype=np.float64).reshape(-1, 2)

        xmin = np.min(pts[:, 0])
        xmax = np.max(pts[:, 0])
        ymin = np.min(pts[:, 1])
        ymax = np.max(pts[:, 1])
        center_w = (math.floor(xmin) + math.ceil(xmax)) / 2.0
        center_h = (math.floor(ymin) + math.ceil(ymax)) / 2.0
        center = torch.Tensor([center_w, center_h])   # x, y order

        # Longest side of the landmark bounding box, in units of 200 px,
        # enlarged by 25% so the crop keeps some margin around the face.
        scale = max(math.ceil(xmax) - math.floor(xmin), math.ceil(ymax) - math.floor(ymin)) / 200.0
        scale *= 1.25

        if self.random_scale:
            scale = scale * (random.uniform(1 - self.scale_factor, 1 + self.scale_factor))

        r = 0
        if self.random_rotation:
            r = random.uniform(-self.rot_factor, self.rot_factor) if random.random() <= 0.6 else 0

        if self.random_flip and random.random() <= 0.5:
            img = np.fliplr(img)
            pts = self.fliplr_joints(pts, width=img.shape[1])
            center[0] = img.shape[1] - center[0]

        img, scale_factor = self.crop(img, center, scale, [256, 256], rot=r)

        # Map landmarks into crop space; points whose y is not positive are left as-is.
        tpts = pts.copy()
        for i in range(tpts.shape[0]):
            if tpts[i, 1] > 0:
                tpts[i, 0:2] = self.transform_pixel(tpts[i, 0:2] + 1, center, scale * scale_factor, [256, 256], rot=r)

        img = (img - self.mean) / self.std
        img = img.transpose([2, 0, 1])
        img = torch.Tensor(img)

        # for [y1, x1, y2, x2, ...] order
        tpts = np.fliplr(tpts).flatten().astype(np.float32)
        tpts = torch.LongTensor(tpts)

        # [x1, y1, x2, y2, ...] order
        pts = torch.FloatTensor(pts.astype(np.float32)).flatten()

        return img, tpts, pts, center, scale

    def fliplr_joints(self, landmark_coordinates, width):
        """Mirror landmarks horizontally and swap left/right counterparts.

        The array is modified in place and also returned.

        Args:
            landmark_coordinates: Array of ``(x, y)`` rows, one per landmark.
            width: Image width used to reflect the x coordinates.
        """
        landmark_coordinates[:, 0] = width - landmark_coordinates[:, 0]
        for left, right in self._FLIP_PAIRS:
            # Fancy indexing copies the right-hand side first, so this is a true swap.
            landmark_coordinates[[left - 1, right - 1]] = landmark_coordinates[[right - 1, left - 1]]
        return landmark_coordinates

    def get_transform(self, center, scale, output_size, rot=0):
        """Build the 3x3 homogeneous matrix from image to crop coordinates.

        Args:
            center: Crop centre ``(x, y)`` in image coordinates.
            scale: Crop side length divided by 200.
            output_size: Crop size as ``[height, width]``.
            rot: Rotation in degrees applied about the crop centre.
        """
        h = 200 * scale
        t = np.zeros((3, 3))
        t[0, 0] = float(output_size[1]) / h
        t[1, 1] = float(output_size[0]) / h
        t[0, 2] = output_size[1] * (-float(center[0]) / h + .5)
        t[1, 2] = output_size[0] * (-float(center[1]) / h + .5)
        t[2, 2] = 1

        if not rot == 0:
            rot = -rot
            rot_mat = np.zeros((3, 3))
            rot_rad = rot * np.pi / 180
            sn, cs = np.sin(rot_rad), np.cos(rot_rad)
            rot_mat[0, :2] = [cs, -sn]
            rot_mat[1, :2] = [sn, cs]
            rot_mat[2, 2] = 1
            # Rotate about the crop centre: shift to origin, rotate, shift back.
            t_mat = np.eye(3)
            t_mat[0, 2] = -output_size[1] / 2
            t_mat[1, 2] = -output_size[0] / 2
            t_inv = t_mat.copy()
            t_inv[:2, 2] *= -1
            t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
        return t

    def transform_pixel(self, pt, center, scale, output_size, invert=0, rot=0):
        """Map a 1-based pixel coordinate through the crop transform.

        Args:
            pt: ``(x, y)`` using 1-based indexing.
            invert: When truthy, map from crop space back to image space.

        Returns:
            Integer ``(x, y)`` array, truncated toward zero and 1-based.
        """
        t = self.get_transform(center, scale, output_size, rot=rot)
        if invert:
            t = np.linalg.inv(t)
        new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.])
        new_pt = np.dot(t, new_pt)
        return new_pt[:2].astype(int) + 1

    def crop(self, img, center, scale, output_size, rot=0):
        """Cut the face box out of ``img``, zero-padding outside the image.

        Pixel values are divided by 255 while copying.

        Returns:
            Tuple ``(crop, scale_adjustment)`` where ``crop`` is resized to
            ``output_size`` and ``scale_adjustment`` is the ratio between the
            rotated, trimmed crop size and the unrotated crop size (1.0 when
            ``rot`` is 0).
        """
        center_new = center.clone()
        scale_adjustment = 1.

        # Upper-left and bottom-right corners of the crop in image coordinates.
        ul = np.array(self.transform_pixel([0, 0], center_new, scale, output_size, invert=1))
        br = np.array(self.transform_pixel(output_size, center_new, scale, output_size, invert=1))
        original_size = (br - ul)[0]

        if not rot == 0:
            # Grow the box so that rotating does not cut off the corners.
            pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
            ul -= pad
            br += pad

        new_shape = [br[1] - ul[1], br[0] - ul[0]]
        if len(img.shape) > 2:
            new_shape += [img.shape[2]]
        new_img = np.zeros(new_shape, dtype=np.float32)

        img_h, img_w = img.shape[0], img.shape[1]
        # Destination range inside the crop and matching source range inside the image.
        new_x = max(0, -ul[0]), min(br[0], img_w) - ul[0]
        new_y = max(0, -ul[1]), min(br[1], img_h) - ul[1]
        old_x = max(0, ul[0]), min(img_w, br[0])
        old_y = max(0, ul[1]), min(img_h, br[1])
        new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] / 255

        if not rot == 0:
            new_img = ndimage.rotate(new_img, rot)
            new_img = new_img[pad:-pad, pad:-pad]
            scale_adjustment = new_img.shape[0] / original_size

        new_img = cv2.resize(new_img.astype(np.float32), output_size)
        return new_img, scale_adjustment




def load_data(task, batch_size, random_scale, random_flip, random_rotation):
    """Create the 300W train and test data loaders.

    Args:
        task: Accepted for interface compatibility; not used by this function.
        batch_size: Batch size for both loaders.
        random_scale, random_flip, random_rotation: Augmentation switches
            forwarded to the training dataset only; the test dataset is
            always built without augmentation.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    root = './dataset/'
    # Settings shared by both loaders.
    # NOTE(review): drop_last=True also applies to the test loader, so a final
    # partial test batch is not evaluated - confirm this is intended.
    common = dict(batch_size=batch_size, num_workers=0, drop_last=True)

    train_loader = torch.utils.data.DataLoader(
        IBUG_300W(root, True, random_scale, random_flip, random_rotation),
        shuffle=True,
        **common
    )
    test_loader = torch.utils.data.DataLoader(
        IBUG_300W(root, False, random_scale=False, random_flip=False, random_rotation=False),
        shuffle=False,
        **common
    )
    return train_loader, test_loader



def load_state(train = False) :
    """Load the pretrained 300W checkpoint and split it per facial part.

    Args:
        train: Accepted for interface compatibility; not used by this function.

    Returns:
        Tuple ``(leye_state, reye_state, mouth_state, nose_state, jaw_state)``.
        Each entry is a dict with the keys ``'f_net'``, ``'c_net'``,
        ``'relative_c_net'``, ``'preferred_z_f'`` and ``'preferred_z_c'``.
        ``'c_net'`` and ``'relative_c_net'`` refer to the same checkpoint
        entries in every dict.
    """
    path = "./pretrained_modules/"
    # map_location keeps this loader consistent with the other checkpoint
    # loaders in this module and lets a checkpoint saved on a GPU be opened
    # on a CPU-only machine.
    state = torch.load(path + "300W_state.pth", map_location=device)

    def part_state(part):
        # Gather the part-specific entries next to the shared networks.
        return {
            'f_net': state['f_net_' + part],
            'c_net': state['c_net'],
            'relative_c_net': state['relative_c_net'],
            'preferred_z_f': state[part + '_z_f'],
            'preferred_z_c': state[part + '_z_c'],
        }

    return tuple(part_state(part) for part in ('leye', 'reye', 'mouth', 'nose', 'jaw'))




def list_mean(hist) :
    """Return the arithmetic mean of the values in ``hist``.

    Raises:
        ZeroDivisionError: If ``hist`` is empty.
    """
    total = sum(hist)
    count = len(hist)
    return total / count



def load_dirpolnet(dirpolnet_leye, dirpolnet_reye,
                   dirpolnet_mouth, dirpolnet_nose, dirpolnet_jaw) :
    """Restore the five part-specific networks from ``dirPolNet.pth``.

    Each module receives the checkpoint entry named ``habit_<part>`` through
    its ``load_state_dict`` method. The modules are updated in place and
    returned in the order they were passed.
    """
    path = "./pretrained_modules/"
    checkpoint = torch.load(path + "dirPolNet.pth", map_location=device)

    targets = (
        ('habit_leye', dirpolnet_leye),
        ('habit_reye', dirpolnet_reye),
        ('habit_mouth', dirpolnet_mouth),
        ('habit_nose', dirpolnet_nose),
        ('habit_jaw', dirpolnet_jaw),
    )
    for key, module in targets:
        module.load_state_dict(checkpoint[key])

    return dirpolnet_leye, dirpolnet_reye, dirpolnet_mouth, dirpolnet_nose, dirpolnet_jaw



def load_optuna_setting() :
    """Read the stored hyper-parameter settings for 300W.

    Returns:
        Tuple of ten values taken from ``optuna_setting_300W.pth``, in this
        order: ``lambda_control_start``, ``lambda_decrease``,
        ``lambda_f_init``, ``lambda_freq``, ``thr_control_start``,
        ``thr_increase``, ``thr_init``, ``thr_freq``, ``lambda_f_1stage``,
        ``lambda_f_2stage``.
    """
    path = "./pretrained_modules/"
    settings = torch.load(path + "optuna_setting_300W.pth", map_location=device)

    # Checkpoint keys listed in the order the caller unpacks them.
    ordered_keys = (
        'lambda_control_start',
        'lambda_decrease',
        'lambda_f_init',
        'lambda_freq',
        'thr_control_start',
        'thr_increase',
        'thr_init',
        'thr_freq',
        'lambda_f_1stage',
        'lambda_f_2stage',
    )
    return tuple(settings[key] for key in ordered_keys)



def abs_coord_to_norm(c, img_size):
    """Map absolute pixel coordinates to the normalised range ``[-1, 1]``.

    Pixel 0 maps to -1 and pixel ``img_size - 1`` maps to +1 along each axis.

    Args:
        c: Coordinates whose last dimension matches ``len(img_size)``.
        img_size: Sequence with the image extent along each coordinate axis.

    Returns:
        Normalised coordinates on the same device as ``c`` when ``c`` is a
        tensor, otherwise on the module-level ``device``.
    """
    # Follow the input's device so CPU tensors also work on a CUDA machine
    # (and vice versa) instead of always forcing the module-level device.
    target = c.device if torch.is_tensor(c) else device
    extent = (torch.FloatTensor(img_size) - 1).to(target)
    return (2 * c / extent) - 1



def norm_coord_to_abs(c, img_size) :
    """Map normalised ``[-1, 1]`` coordinates back to rounded pixel positions.

    -1 maps to pixel 0 and +1 maps to pixel ``img_size - 1`` along each axis;
    the result is rounded with ``torch.round``.

    Args:
        c: Normalised coordinates whose last dimension matches ``len(img_size)``.
        img_size: Sequence with the image extent along each coordinate axis.

    Returns:
        Rounded coordinates on the same device as ``c`` when ``c`` is a
        tensor, otherwise on the module-level ``device``.
    """
    # Follow the input's device so CPU tensors also work on a CUDA machine
    # (and vice versa) instead of always forcing the module-level device.
    target = c.device if torch.is_tensor(c) else device
    half_extent = (torch.FloatTensor(img_size) - 1).to(target) / 2
    return torch.round((c + 1) * half_extent)



def KL_div_from_mean(mean1, mean2) :
    """Return half the squared difference between ``mean1`` and ``mean2``."""
    gap = mean1 - mean2
    return 0.5 * (gap ** 2)



def get_transform(center, scale, output_size, rot=0):
    """Build the 3x3 homogeneous matrix from image to crop coordinates.

    Args:
        center: Crop centre ``(x, y)`` in image coordinates.
        scale: Crop side length divided by 200.
        output_size: Crop size as ``[height, width]``.
        rot: Rotation in degrees applied about the crop centre.

    Returns:
        ``numpy`` array of shape ``(3, 3)``.
    """
    out_h, out_w = output_size[0], output_size[1]
    box = 200 * scale

    # Scale the box to the output size and move its centre to the crop centre.
    mat = np.zeros((3, 3))
    mat[0, 0] = float(out_w) / box
    mat[1, 1] = float(out_h) / box
    mat[0, 2] = out_w * (-float(center[0]) / box + .5)
    mat[1, 2] = out_h * (-float(center[1]) / box + .5)
    mat[2, 2] = 1

    if rot == 0:
        return mat

    angle = -rot * np.pi / 180
    sin_a, cos_a = np.sin(angle), np.cos(angle)
    rotation = np.zeros((3, 3))
    rotation[0, :2] = [cos_a, -sin_a]
    rotation[1, :2] = [sin_a, cos_a]
    rotation[2, 2] = 1

    # Rotate about the crop centre: shift to origin, rotate, shift back.
    to_origin = np.eye(3)
    to_origin[0, 2] = -out_w / 2
    to_origin[1, 2] = -out_h / 2
    from_origin = to_origin.copy()
    from_origin[:2, 2] *= -1

    mat = np.dot(to_origin, mat)
    mat = np.dot(rotation, mat)
    return np.dot(from_origin, mat)


def transform_pixel(pt, center, scale, output_size, invert=0, rot=0):
    """Map a 1-based pixel coordinate through the crop transform.

    Args:
        pt: ``(x, y)`` using 1-based indexing.
        center, scale, output_size, rot: Forwarded to ``get_transform``.
        invert: When truthy, apply the inverse transform (crop to image).

    Returns:
        Integer ``(x, y)`` array, truncated toward zero and 1-based.
    """
    matrix = get_transform(center, scale, output_size, rot=rot)
    matrix = np.linalg.inv(matrix) if invert else matrix

    # Shift to 0-based homogeneous coordinates before applying the matrix.
    homogeneous = np.array([pt[0] - 1, pt[1] - 1, 1.])
    mapped = np.dot(matrix, homogeneous)
    return mapped[:2].astype(int) + 1



# for test : original space
def NME_calc(inferred_coord_abs_x_y, pts, center, scale) :
    """Compute the per-sample inter-ocular normalised mean error.

    Predictions are mapped from the 256x256 crop back to the original image
    with ``transform_pixel(..., invert=1)`` before being compared to ``pts``.

    Args:
        inferred_coord_abs_x_y: Predicted crop-space ``(x, y)`` landmarks,
            indexed as ``[batch, landmark]``.
        pts: Ground-truth landmarks, viewable as ``(-1, 68, 2)``.
        center: Per-sample crop centres.
        scale: Per-sample crop scales.

    Returns:
        Tensor with one value per sample: the mean landmark error divided by
        the distance between landmarks 36 and 45 (0-based).
    """
    pts = pts.view(-1, 68, 2)

    B = inferred_coord_abs_x_y.size(0)
    n_l = inferred_coord_abs_x_y.size(1)

    # Work on a private copy: for a CPU tensor, .numpy() shares memory with
    # the tensor, so writing into it would overwrite the caller's predictions.
    # detach() additionally allows inputs that still require grad.
    coords = inferred_coord_abs_x_y.detach().cpu().numpy().copy()

    for b in range(B) :
        for l in range(n_l) :
            coords[b, l] = transform_pixel(coords[b, l]+1, center[b], scale[b], [256,256], invert=1, rot=0)

    # Convert back to a tensor explicitly instead of relying on implicit
    # tensor/ndarray arithmetic.
    predicted = torch.from_numpy(coords.reshape(-1, 68, 2)).to(pts.device)
    error = torch.norm(pts - predicted, dim=-1).mean(-1)

    inter_ocular_distance = torch.norm(pts[:, 45] - pts[:, 36], dim=-1)

    NME_ocular = error / inter_ocular_distance
    return NME_ocular



def NME_calc_landmarkwise(inferred_coord_abs_x_y, pts, center, scale) :
    """Compute the inter-ocular normalised error for every landmark.

    Predictions are mapped from the 256x256 crop back to the original image
    with ``transform_pixel(..., invert=1)`` before being compared to ``pts``.

    Args:
        inferred_coord_abs_x_y: Predicted crop-space ``(x, y)`` landmarks,
            indexed as ``[batch, landmark]``.
        pts: Ground-truth landmarks, viewable as ``(-1, 68, 2)``.
        center: Per-sample crop centres.
        scale: Per-sample crop scales.

    Returns:
        Tensor of shape ``[B, 68]``: each landmark's error divided by the
        sample's distance between landmarks 36 and 45 (0-based).
    """
    pts = pts.view(-1, 68, 2)

    B = inferred_coord_abs_x_y.size(0)
    n_l = inferred_coord_abs_x_y.size(1)

    # Work on a private copy: for a CPU tensor, .numpy() shares memory with
    # the tensor, so writing into it would overwrite the caller's predictions.
    # detach() additionally allows inputs that still require grad.
    coords = inferred_coord_abs_x_y.detach().cpu().numpy().copy()

    for b in range(B) :
        for l in range(n_l) :
            coords[b, l] = transform_pixel(coords[b, l]+1, center[b], scale[b], [256,256], invert=1, rot=0)

    # Convert back to a tensor explicitly instead of relying on implicit
    # tensor/ndarray arithmetic.
    predicted = torch.from_numpy(coords.reshape(-1, 68, 2)).to(pts.device)
    error = torch.norm(pts - predicted, dim=-1)  # [B, n_l]
    inter_ocular_distance = torch.norm(pts[:, 45] - pts[:, 36], dim=-1)
    NME_ocular = error / inter_ocular_distance.view(-1, 1)
    return NME_ocular
